{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a458613f",
   "metadata": {},
   "source": [
    "# SFT_GRPO_1: Can LLM master Wordle?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e083d05",
   "metadata": {},
   "source": [
    "Start by load dependencies and setting up the Predibase client, which you'll use to call both base and finetuned models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed682fe5-6329-42da-9e5e-ce019aba8c77",
   "metadata": {
    "height": 234
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "\n",
    "# Initialize client to query Qwen2.5 7B Instruct on Predibase using the OpenAI client.\n",
    "\n",
    "_ = load_dotenv()\n",
    "\n",
    "client = OpenAI(\n",
    "    base_url=os.environ[\"PREDIBASE_MODEL_QWEN_URL\"], # Qwen 2.5 7B Instruct\n",
    "    api_key=os.environ[\"PREDIBASE_API_KEY\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b478e01c",
   "metadata": {},
   "source": [
    "Load the model tokenizer from Hugging face:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84308d60-a3b3-45c2-8918-ddde153a9202",
   "metadata": {
    "height": 81
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "base_model_id = \"Qwen/Qwen2.5-7B-Instruct\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(base_model_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fa5c160",
   "metadata": {},
   "source": [
    "## Setting up prompts to play Wordle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da4482b0-cc4e-4b77-9a04-b3401a824c9d",
   "metadata": {
    "height": 506
   },
   "outputs": [],
   "source": [
    "SYSTEM_PROMPT = \"\"\"\n",
    "You are playing Wordle, a word-guessing game.\n",
    "\n",
    "### Game Rules:\n",
    "- You have **6 tries** to guess a secret **5-letter** word.\n",
    "- Each guess must be a valid **5-letter English word**.\n",
    "- After each guess, you will receive feedback indicating how close \n",
    "your guess was.\n",
    "\n",
    "### Feedback Format:\n",
    "Each letter in your guess will receive one of three symbols:\n",
    "1. ✓ : The letter is in the word and in the CORRECT position.\n",
    "2. - : The letter is in the word but in the WRONG position.\n",
    "3. x : The letter is NOT in the word.\n",
    "\n",
    "### Example:\n",
    "Secret Word: BRISK\n",
    "\n",
    "Guess 1: STORM → Feedback: S(-) T(x) O(x) R(-) M(x)\n",
    "Guess 2: BRAVE → Feedback: B(✓) R(✓) A(x) V(x) E(x)\n",
    "Guess 3: BRISK → Feedback: B(✓) R(✓) I(✓) S(✓) K(✓)\n",
    "\n",
    "### Response Format:\n",
    "Think through the problem and feedback step by step. Make sure to \n",
    "first add your step by step thought process within <think> </think> \n",
    "tags. Then, return your guessed word in the following format: \n",
    "<guess> guessed-word </guess>.\n",
    "\"\"\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc165b8b-f396-471d-9d84-16d964ea0cb2",
   "metadata": {
    "height": 370
   },
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from enum import Enum\n",
    "from typing import List\n",
    "\n",
    "\n",
    "class LetterFeedback(Enum):\n",
    "    CORRECT = \"✓\"\n",
    "    WRONG_POS = \"-\"\n",
    "    WRONG_LETTER = \"x\"\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class GuessWithFeedback:\n",
    "    guess: str\n",
    "    feedback: List[LetterFeedback]\n",
    "\n",
    "    def __repr__(self) -> str:\n",
    "        \"\"\"Returns a readable string showing the guess alongside\n",
    "        its letter-by-letter feedback.\"\"\"\n",
    "        feedback_str = \" \".join(f\"{letter}({fb.value})\" for letter, fb in zip(self.guess, self.feedback))\n",
    "        return f\"{self.guess} → Feedback: {feedback_str}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3ac5b76-c834-4674-9e0b-484a2219ec69",
   "metadata": {
    "height": 166
   },
   "outputs": [],
   "source": [
    "def render_user_prompt(past_guesses: List[GuessWithFeedback]) -> str:\n",
    "    \"\"\"Creates a user-facing prompt that includes past guesses \n",
    "    and their feedback.\"\"\"\n",
    "    prompt = \"Make a new 5-letter word guess.\"\n",
    "    if past_guesses:\n",
    "        prompt += \"\\n\\nHere is some previous feedback:\"\n",
    "        for i, guess in enumerate(past_guesses):\n",
    "            prompt += f\"\\nGuess {i+1}: {guess}\"\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2386034-ae0b-4a16-ad38-83b01982d97b",
   "metadata": {
    "height": 387
   },
   "outputs": [],
   "source": [
    "def render_prompt(past_guesses: List[GuessWithFeedback]):\n",
    "    \"\"\"Formats a full chat prompt using a system message, user \n",
    "    prompt, and assistant preamble to start the model's \n",
    "    step-by-step reasoning.\"\"\"\n",
    "    messages = [\n",
    "        {\n",
    "            \"role\": \"system\",\n",
    "            \"content\": SYSTEM_PROMPT\n",
    "        },\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": render_user_prompt(past_guesses)\n",
    "        },\n",
    "        {\n",
    "            \"role\": \"assistant\",\n",
    "            \"content\": \"Let me solve this step by step.\\n<think>\"\n",
    "        }\n",
    "    ]\n",
    "\n",
    "    return tokenizer.apply_chat_template(\n",
    "        messages, tokenize=False, continue_final_message=True\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8aca9bc-a37f-474a-aa43-deedea322097",
   "metadata": {
    "height": 370
   },
   "outputs": [],
   "source": [
    "def generate_stream(prompt: str, adapter_id: str = \"\") -> str:\n",
    "    \"\"\"Streams a model-generated response from a prompt in \n",
    "    real-time and prints it as it arrives.\"\"\"\n",
    "    response = client.completions.create(\n",
    "        model=adapter_id,\n",
    "        prompt=prompt,\n",
    "        # Produce deterministic responses for evaluation\n",
    "        temperature=0.0, \n",
    "        max_tokens=2048,\n",
    "        stream=True,\n",
    "    )\n",
    "    \n",
    "    completion = \"\"\n",
    "    for chunk in response:\n",
    "        if chunk.choices[0].text is not None:\n",
    "            content = chunk.choices[0].text\n",
    "            print(content, end=\"\", flush=True)\n",
    "            completion += content\n",
    "    print()\n",
    "\n",
    "    return completion"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ff5c34e",
   "metadata": {},
   "source": [
    "## Play a single turn of Wordle\n",
    "\n",
    "Start by setting up the prompt with feedback for two prior guesses:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c4ea278-1bd1-433d-9f7c-e73da319ad3d",
   "metadata": {
    "height": 404
   },
   "outputs": [],
   "source": [
    "secret_word = \"CRAFT\"\n",
    "\n",
    "past_guesses = [\n",
    "    GuessWithFeedback(\n",
    "        \"CRANE\", [\n",
    "            LetterFeedback.CORRECT, \n",
    "            LetterFeedback.CORRECT, \n",
    "            LetterFeedback.CORRECT, \n",
    "            LetterFeedback.WRONG_LETTER, \n",
    "            LetterFeedback.WRONG_LETTER,\n",
    "        ]),\n",
    "    GuessWithFeedback(\n",
    "        \"CRASH\", [\n",
    "            LetterFeedback.CORRECT, \n",
    "            LetterFeedback.CORRECT, \n",
    "            LetterFeedback.CORRECT, \n",
    "            LetterFeedback.WRONG_LETTER, \n",
    "            LetterFeedback.WRONG_LETTER,\n",
    "        ]),\n",
    "]\n",
    "\n",
    "prompt = render_prompt(past_guesses)\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f5c6318",
   "metadata": {},
   "source": [
    "Prompt the base model and examine its response:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c08acaf0-53c8-439d-a30d-939ee43f9283",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "base_completion = generate_stream(prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a2bd44d",
   "metadata": {},
   "source": [
    "Now prompt the RFT model and see how it responds:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f03f6b1c",
   "metadata": {},
   "source": [
    "<div style=\"background-color:#fff6ff; padding:13px; border-width:3px; border-color:#efe6ef; border-style:solid; border-radius:6px\"> <b>NOTE:</b> The cell below loads a model that has been tuned using RFT to play wordle.\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f89de204-6e6f-4c85-93cb-98f4003f58b4",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "# prompt fine-tuned model\n",
    "ft_completion = generate_stream(prompt, adapter_id=\"wordle-dlai/2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45a2ef51",
   "metadata": {},
   "source": [
    "## Playing a game of wordle\n",
    "\n",
    "Start by setting up a helper function that gets feedback on a guess:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee0bb11f-d1f9-46dd-a3b8-e54ce34e49d7",
   "metadata": {
    "height": 234
   },
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def get_feedback(guess: str, secret_word: str) -> List[LetterFeedback]:\n",
    "    valid_letters = set(secret_word)\n",
    "    feedback = []\n",
    "    for letter, secret_letter in zip(guess, secret_word):\n",
    "        if letter == secret_letter:\n",
    "            feedback.append(LetterFeedback.CORRECT)\n",
    "        elif letter in valid_letters:\n",
    "            feedback.append(LetterFeedback.WRONG_POS)\n",
    "        else:\n",
    "            feedback.append(LetterFeedback.WRONG_LETTER)\n",
    "    return feedback"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ab8fff9",
   "metadata": {},
   "source": [
    "Now create a `next_turn` function that incorporates feedback on a guess into the prompt to the LLM:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c755edc-3484-4f82-bf57-84f7aff51e5e",
   "metadata": {
    "height": 438
   },
   "outputs": [],
   "source": [
    "def next_turn(\n",
    "    past_guesses: List[GuessWithFeedback], \n",
    "    secret_word: str, \n",
    "    adapter_id = \"\"\n",
    "):\n",
    "    prompt = render_prompt(past_guesses)\n",
    "    completion = generate_stream(prompt, adapter_id)\n",
    "    match = re.search(\n",
    "        r\"<guess>\\s*(.*?)\\s*</guess>\", completion, re.DOTALL\n",
    "    )\n",
    "    if not match:\n",
    "        raise RuntimeError(\"invalid guess\")\n",
    "    \n",
    "    guess = match.group(1).upper()\n",
    "    feedback = get_feedback(guess, secret_word)\n",
    "    past_guesses.append(GuessWithFeedback(guess, feedback))\n",
    "    print(\"\\n\\n\")\n",
    "    print((\"-\" * 100) + \"\\n\")\n",
    "    for past_guess in past_guesses:\n",
    "        print(past_guess)\n",
    "    \n",
    "    if guess == secret_word:\n",
    "        print(\"🎉 SUCCESS 🎉\")\n",
    "    elif len(past_guesses) >= 6:\n",
    "        print(\"❌ better luck next time... ❌\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22416f78",
   "metadata": {},
   "source": [
    "Try playing with the base model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31c2abc6-2099-489c-83ec-d8cae1f872c6",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "secret_word = \"BRICK\"\n",
    "past_guesses = []\n",
    "adapter_id = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2412db7c-5e7c-4ec7-ad80-78ef86155187",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "next_turn(past_guesses, secret_word, adapter_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "722d2243",
   "metadata": {},
   "source": [
    "Now try with the finetuned model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f408fe2-f212-4fc7-8372-863d488220c1",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "secret_word = \"BRICK\"\n",
    "past_guesses = []\n",
    "adapter_id = \"wordle-dlai/2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c187f2-b4b4-4127-bff6-df8c466236de",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "next_turn(past_guesses, secret_word, adapter_id)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
